Add torch.compile for Whisper #30949
Conversation
if (
    is_cross_attention
    and past_key_value is not None
    and past_key_value[0].shape[2] == key_value_states.shape[1]
):
    # reuse k,v, cross_attentions
    key_states = past_key_value[0]
    value_states = past_key_value[1]
elif is_cross_attention:
Here, according to my understanding, key_states and value_states are computed once and for all from the encoder hidden states when we do cross-attention, so ideally we can cache them. However, these are a little different from the existing caches: we don't need to update the cache at every generation step, only once in the first step. So we could either create another Cache class, in which case we need to pass in two caches (one for self-attention and one for cross-attention); even if we wrap both in a single cache class, we still need to modify the current get_cache
logic, because it will definitely need more parameters to initialize the cache, and I don't know whether the use case is general enough to justify a new cache class. Or we can just initialize and update the cross-attention kv cache within every layer, like recurrent gemma, but this requires manually resetting the cache between generations, because the current generation logic doesn't seem to consider the case where the model owns an inherent cache. I am personally for the latter solution; however, both will need some changes in the generation utils file, and I'm not sure which way might make dynamo or cudagraphs unhappy. What do you think is best @ArthurZucker
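For concreteness, here is a minimal sketch of the second option (the per-layer cross-attention cache); the class and attribute names are hypothetical and only meant to illustrate the lifecycle, not the PR's actual code:

import torch
from torch import nn


class CrossAttentionWithOwnCache(nn.Module):
    """Hypothetical cross-attention layer that owns its k/v cache and fills it
    once, on the first generation step."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.cached_key = None
        self.cached_value = None

    def project_key_value(self, key_value_states: torch.Tensor):
        if self.cached_key is None:
            # first generation step: project the encoder hidden states once
            self.cached_key = self.k_proj(key_value_states)
            self.cached_value = self.v_proj(key_value_states)
        # later steps just reuse the cached projections
        return self.cached_key, self.cached_value

    def reset_cache(self):
        # has to be called manually between generations, since generate()
        # does not know about a cache owned by the model itself
        self.cached_key = None
        self.cached_value = None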
Your understanding is correct: we compute the k/v states for the cross-attention once in the first forward pass, and then save them to the cache. We would indeed want to cache the k/v states to avoid re-computing them at every step.
What do you think about the design proposed in this PR? #28931 (comment) Similar to your solution, it uses two caches: one for the self-attention, and one for the cross-attention. Note that the cache specifics are out of date given the recent changes to the cache API, but the high-level design remains valid!
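Roughly, the high-level idea there looks something like this (the class and method names are illustrative only, and the cache specifics have changed since that comment):

class SelfAndCrossAttentionCache:
    """Illustrative container pairing a decoder self-attention cache with a
    cross-attention cache, so the cross-attention k/v computed on the first
    forward pass can simply be read back on later steps."""

    def __init__(self, self_attention_cache, cross_attention_cache):
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

    def update_self_attention(self, key_states, value_states, layer_idx):
        # written on every decoding step
        return self.self_attention_cache.update(key_states, value_states, layer_idx)

    def update_cross_attention(self, key_states, value_states, layer_idx):
        # written once, on the first decoding step, then only read
        return self.cross_attention_cache.update(key_states, value_states, layer_idx)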
Yeah, it would be great if two StaticCaches worked, but the issue is that in that case we cannot tell whether we are in the first generation step, because the cache shape is now always fixed to hold the maximum generation length, and using get_seq_length
(to check whether we have processed any tokens yet during tracing) in a branch condition will indeed cause graph breaks.
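To make that concrete, the check would have to look something like this hypothetical helper, whose result depends on the cache contents, so branching on it inside the traced forward breaks the graph:

from transformers.cache_utils import StaticCache


def is_first_generation_step(cache: StaticCache, layer_idx: int = 0):
    # With StaticCache the key/value tensors are pre-allocated to the maximum
    # generation length, so shape checks can no longer tell the first step
    # apart from later ones; the only remaining signal is the cache contents,
    # which makes this a data-dependent condition under torch.compile.
    return cache.get_seq_length(layer_idx) == 0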
My current solution is to use two StaticCaches and add another flag in the attention layer to mark whether we are in the first generation step. This causes a recompile in the second step, just like in llama and mistral, but should work fine for the subsequent steps. It's still nasty, though, because it breaks the current stand-alone cache design. Maybe I should indeed create a new OneShotCache
class which hides the flag inside (rough sketch below)? cc @ArthurZucker
@gante @sanchit-gandhi
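One possible shape for that OneShotCache, sketched only from the description above (the method and attribute names mirror the later diff in this thread, but the details are not the PR's final implementation):

from typing import List, Optional

import torch


class OneShotCacheSketch:
    """Cross-attention cache that is written exactly once, on the first
    generation step, and only read afterwards; the 'already filled?' flag
    lives inside the cache instead of inside the attention layer."""

    def __init__(self, num_layers: int):
        self.key_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.value_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self._filled: List[bool] = [False] * num_layers

    def query_cache_filled_status(self, layer_idx: int) -> bool:
        return self._filled[layer_idx]

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
        if not self._filled[layer_idx]:
            # only the first generation step ever writes the cache
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
            self._filled[layer_idx] = True
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

Since the flag is a plain Python bool, flipping it after the first step would still trigger one recompile, as described above, but the subsequent steps should compile cleanly.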
The current solution uses a new OneShotCache, and a tuple of two caches is expected for conditional generation.
if self.config.is_encoder_decoder:
    # manually set another cache for cross attention
    encoder_outputs = model_kwargs["encoder_outputs"][0]
    model_kwargs["past_key_values"] = (
        model_kwargs["past_key_values"],
        self._get_cache("one_shot", encoder_outputs.shape[0], encoder_outputs.shape[1], '_cross_attn_cache')
    )
I wonder if there is a better way to do this: in the encoder-decoder scenario we need a tuple of two caches here according to the current design, but this seems hardcoded and easy to break.
if is_cross_attention and (
    (isinstance(past_key_value, DynamicCache) and self.layer_idx < len(past_key_value.key_cache))
    or (isinstance(past_key_value, OneShotStaticCache) and past_key_value.query_cache_filled_status(self.layer_idx))
):
    key_states = past_key_value.key_cache[self.layer_idx]
    value_states = past_key_value.value_cache[self.layer_idx]
    need_update_cache = False
Actually, we can always use OneShotStaticCache
for the kv cache here for simplicity, because using DynamicCache
won't give us any memory benefit on the cross-attention kv cache.
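If the cross-attention kv cache is always an OneShotStaticCache, the branch above only needs the filled-status check; a hypothetical helper to show the intent (names follow the diff above, not the PR's final code):

def reuse_cached_cross_attention(past_key_value, layer_idx):
    # Returns the cached (key_states, value_states) when this layer's
    # cross-attention k/v were already written on the first generation step,
    # and None otherwise; the DynamicCache length check is no longer needed.
    if past_key_value.query_cache_filled_status(layer_idx):
        return past_key_value.key_cache[layer_idx], past_key_value.value_cache[layer_idx]
    return None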
This PR adds torch.compile support for the Whisper model, which is an encoder-decoder architecture.